/*
 * Copyright (c) Huawei Technologies Co., Ltd. 2024-2024. All rights reserved.
 */
#ifndef FTN_IMPL_HELPER_AVX2_H
#define FTN_IMPL_HELPER_AVX2_H

// Builds one 32-bit selection mask per 32-element chunk of `values` and returns the
// total number of selected elements.
//
// For chunk j, lane i corresponds to values[32 * j + i]. The lane is selected when it
// compares equal to `threshold` (kIsEquality == true) or less than `threshold`
// (kIsEquality == false). The bit layout follows GetComparisonMask; presumably bit i is
// lane i. NOTE(review): confirm.
//
// Parameters:
//   values     - at least 32 * n_masks readable elements (32 are loaded per chunk).
//   masks      - output, n_masks entries; masks[j] receives chunk j's selection bits.
//   n_masks    - number of 32-element chunks. With 0 nothing is read or written.
//   final_mask - ANDed into the last chunk's mask. Presumably this clears lanes past the
//                logical end of the data; verify against callers.
//   threshold  - comparison value.
// Returns the popcount of all masks after final_mask has been applied.
template <bool kIsEquality, typename DistT>
SCANN_SIMD_INLINE size_t CalculateSwapMasks(
    const DistT *values, uint32_t *masks, size_t n_masks, uint32_t final_mask, DistT threshold)
{
    // With no chunks there is no last mask to trim; without this guard
    // masks[n_masks - 1] would index far out of bounds.
    if (n_masks == 0) {
        return 0;
    }

    const auto simd_threshold = Simd<DistT>::Broadcast(threshold);

    size_t n_kept = 0;
    for (size_t j : Seq(n_masks)) {
        auto vals = SimdFor<DistT, 32>::Load(values + 32 * j);
        const uint32_t mask =
            kIsEquality ? GetComparisonMask(vals == simd_threshold) : GetComparisonMask(vals < simd_threshold);
        n_kept += absl::popcount(mask);
        masks[j] = mask;
    }

    // The last mask is trimmed after the loop so the hot loop stays branch-free.
    // Its bits are removed from the count, the mask is trimmed, and the trimmed count is added back.
    uint32_t &last_mask = masks[n_masks - 1];
    n_kept -= absl::popcount(last_mask);
    last_mask &= final_mask;
    n_kept += absl::popcount(last_mask);

    return n_kept;
}

#ifdef SCANN_SVE
// Compacts the elements selected by `masks` to the front of `indices` / `values` and
// returns how many were kept. Element 32 * j + i is kept iff bit i of masks[j] is set.
//
// Preconditions / side effects:
//   - n_masks >= 2 (DCHECKed).
//   - Chunks 0 and 1 (64 indices, 64 values, 2 masks) are first copied past the end, to
//     [n_masks * 32, n_masks * 32 + 64) and masks[n_masks], masks[n_masks + 1].
//     NOTE(review): the buffers therefore need room for 64 extra elements and 2 extra
//     masks. That is not checked here; confirm callers allocate it.
//   - Because chunks 0 and 1 are moved away, reading starts at chunk 2 while writing
//     starts at index 0. The write cursor thus stays behind both read cursors (see the
//     DCHECK_LTs in the loop).
//
// The output order is not the input order. Each main-loop step emits one element from
// two different chunks (mask2's chunk first, then mask1's). Presumably the two
// independent cursors give the "double-ported" name and exist for instruction-level
// parallelism.
//
// threshold_value is not used by this generic implementation.
template <typename DistT, typename DatapointIndexT>
SCANN_SIMD_OUTLINE size_t UseMasksToCompactDoublePortedHelper(
    DatapointIndexT* indices, DistT* values, uint32_t* masks, size_t n_masks, DistT threshold_value)
{
    DCHECK_GE(n_masks, 2);

    // Relocate chunks 0 and 1 (data and masks) to the end so the front can be overwritten.
    std::copy(values, values + 64, values + n_masks * 32);
    std::copy(indices, indices + 64, indices + n_masks * 32);
    std::copy(masks, masks + 2, masks + n_masks);
    n_masks += 2;

    // Two read cursors over consecutive chunks, starting at chunk 2.
    uint32_t mask1 = masks[2];
    DatapointIndexT* indices1 = indices + 2 * 32;
    DistT* values1 = values + 2 * 32;

    uint32_t mask2 = masks[3];
    DatapointIndexT* indices2 = indices + 3 * 32;
    DistT* values2 = values + 3 * 32;

    // masks_ptr points at the mask currently loaded into mask2.
    uint32_t* masks_ptr = masks + 3;
    uint32_t* masks_end = masks + n_masks;

    DatapointIndexT* indices_write_ptr = indices;
    DistT* values_write_ptr = values;

    for (;;) {
    if (ABSL_PREDICT_FALSE(!mask1 || !mask2)) {
        bool proceed_to_cooldown = false;

        // Refill until both cursors have pending bits or the masks run out.
        do {
        // If cursor 1 is drained, hand it cursor 2's chunk; cursor 2 then advances.
        if (!mask1) {
            mask1 = mask2;
            indices1 = indices2;
            values1 = values2;
        }

        if (++masks_ptr >= masks_end) {
            proceed_to_cooldown = true;
            break;
        }

        mask2 = *masks_ptr;
        indices2 += 32;
        values2 += 32;

        } while (ABSL_PREDICT_FALSE(!mask1 || !mask2));

        if (proceed_to_cooldown) break;
    }
    DCHECK(mask1);
    DCHECK(mask2);
    DCHECK_LT(indices_write_ptr, indices1);
    DCHECK_LT(indices_write_ptr, indices2);
    DCHECK_LT(values_write_ptr, values1);
    DCHECK_LT(values_write_ptr, values2);

    // Emit the lowest pending element of each cursor, then clear that bit.
    const int offset2 = bits::FindLSBSetNonZero(mask2);
    const int offset1 = bits::FindLSBSetNonZero(mask1);

    *indices_write_ptr++ = indices2[offset2];
    *values_write_ptr++ = values2[offset2];

    *indices_write_ptr++ = indices1[offset1];
    *values_write_ptr++ = values1[offset1];

    mask2 &= (mask2 - 1);
    mask1 &= (mask1 - 1);
    }

    // Cooldown. On breaking out, mask2 was either already empty or its bits were just
    // moved into mask1, so only mask1 can still hold pending elements.
    while (mask1) {
    const int offset1 = bits::FindLSBSetNonZero(mask1);
    mask1 &= (mask1 - 1);
    *indices_write_ptr++ = indices1[offset1];
    *values_write_ptr++ = values1[offset1];
    }

    DCHECK_EQ(indices_write_ptr - indices, values_write_ptr - values);
    return indices_write_ptr - indices;
}

// SVE specialization for float distances and 32-bit indices.
//
// Rather than walking the mask bits, it re-evaluates `values[i] < threshold_value` with
// SVE over the prefix that can contain selected elements. It then compacts the passing
// (index, value) pairs to the front of `indices` / `values` in their original order and
// returns how many were kept. `masks` is only used to find where that prefix ends (the
// highest set bit of the last non-zero mask).
// NOTE(review): this assumes the masks were built with a `<` comparison against the same
// threshold. Masks built with an equality comparison would not match; confirm with callers.
//
// Unlike the generic helper, this version never writes past the first
// n_masks * 32 elements. It is `inline` because an explicit specialization defined in a
// header is otherwise an ordinary external definition, which causes multiple-definition
// link errors when the header is included in more than one translation unit.
template<>
inline size_t UseMasksToCompactDoublePortedHelper<float, unsigned int>(
    unsigned int* indices, float* values, uint32_t* masks, size_t n_masks, float threshold_value)
{
    // Same precondition as the generic helper. It is checked on the caller-supplied count:
    // after trimming trailing empty masks below, 0 or 1 remaining masks is legitimate.
    DCHECK_GE(n_masks, 2);

    // Trailing all-zero masks select nothing, so stop scanning before them.
    while (n_masks != 0 && masks[n_masks - 1] == 0) {
        --n_masks;
    }
    if (n_masks == 0) {
        // Nothing is selected. This also avoids reading masks[-1] and __builtin_clz(0), which is UB.
        return 0;
    }

    // One past the highest set bit of the last non-zero mask. Only [0, data_size) can hold kept elements.
    const size_t data_size = n_masks * 32 - static_cast<size_t>(__builtin_clz(masks[n_masks - 1]));
    // Number of 32-bit lanes in one SVE vector.
    const size_t vec_length = svcntb() / sizeof(int32_t);

    // Positions are indices rather than pointers, so that data_size < vec_length never forms
    // a pointer before the start of the arrays.
    size_t read_pos = 0;
    size_t write_pos = 0;

    // Full-vector loop. It runs while strictly more than one vector of input remains, so the
    // tail below always sees between 1 and vec_length elements.
    while (data_size - read_pos > vec_length) {
        svuint32_t vec_data = svld1(svptrue_b32(), indices + read_pos);
        svfloat32_t vec_values = svld1(svptrue_b32(), values + read_pos);

        // Lanes with value < threshold_value.
        svbool_t less_mask = svcmplt_n_f32(svptrue_b32(), vec_values, threshold_value);

        // Pack the selected lanes to the front of the vector.
        svuint32_t compacted_data = svcompact(less_mask, vec_data);
        svfloat32_t compacted_values = svcompact(less_mask, vec_values);
        size_t kept = svcntp_b32(svptrue_b32(), less_mask);

        // A full-width store is safe: write_pos <= read_pos, so it only overwrites elements
        // that have already been loaded.
        svst1(svptrue_b32(), indices + write_pos, compacted_data);
        svst1(svptrue_b32(), values + write_pos, compacted_values);

        write_pos += kept;
        read_pos += vec_length;
    }

    // Tail: the remaining 1..vec_length elements, with loads and compares predicated to stay in bounds.
    svbool_t pred = svwhilelt_b32_u32(0, static_cast<uint32_t>(data_size - read_pos));

    svuint32_t vec_data = svld1(pred, indices + read_pos);
    svfloat32_t vec_values = svld1(pred, values + read_pos);

    svbool_t less_mask = svcmplt_n_f32(pred, vec_values, threshold_value);

    svuint32_t compacted_data = svcompact(less_mask, vec_data);
    svfloat32_t compacted_values = svcompact(less_mask, vec_values);

    size_t kept = svcntp_b32(pred, less_mask);
    // Store only the kept lanes so nothing past them is touched.
    svbool_t store_pred = svwhilelt_b32_u32(0, static_cast<uint32_t>(kept));
    svst1(store_pred, indices + write_pos, compacted_data);
    svst1(store_pred, values + write_pos, compacted_values);

    write_pos += kept;

    // Number of kept elements, now stored in indices[0, write_pos) and values[0, write_pos).
    return write_pos;
}
#endif

#endif